
import argparse
import numpy as np
import numpy.lib as npl
import torch
from torch import nn, optim, autograd
from torch.autograd import Variable
import torch.utils.data as Data

from load import get_years, get_sectors, YEARS
from utils import *
from densratio import densratio
#%%

# Command-line configuration for the experiment.
# NOTE(review): the description string looks inherited from another script;
# confirm whether it should be updated for this dataset.
parser = argparse.ArgumentParser(description='Colored MNIST')
# Model / optimiser hyper-parameters.
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--l2_regularizer_weight', type=float, default=0.001)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--n_restarts', type=int, default=1)
# Step from which the penalty weights below start being applied.
parser.add_argument('--penalty_anneal_iters', type=int, default=1)
# The training loop uses the IRM penalty whenever its weight is non-zero and
# only falls back to the REx penalty when the IRM weight is 0.
parser.add_argument('--irm_penalty_weight', type=float, default=1000000.0)
parser.add_argument('--rex_penalty_weight', type=float, default=10000000000.0)
parser.add_argument('--steps', type=int, default=300)
parser.add_argument('--plot', action='store_true')
parser.add_argument('--save', type=str, default='')
# Comma-separated environment ids; an empty --test_envs means "derive the test
# environments from YEARS and --train_envs".
parser.add_argument('--train_envs', type=str, default='2014,2015,2016')
parser.add_argument('--test_envs', type=str, default='')
# Must be an int: the value is forwarded to DataLoader(batch_size=...), which
# rejects strings. (Previously declared as type=str, so any value supplied on
# the command line arrived as a string.)
parser.add_argument('--BATCH_SIZE', type=int, default=100)
flags = parser.parse_args()
#%%
# Environment ids are given as comma-separated integers on the command line.
train_env_ids = [int(s.strip()) for s in flags.train_envs.split(',')]
if flags.test_envs:
    test_env_ids = [int(s.strip()) for s in flags.test_envs.split(',')]
else:
    # Default: every id that appears in exactly one of YEARS / train_env_ids.
    # Use the public ``np.setxor1d``; the ``numpy.lib`` namespace no longer
    # re-exports the set routines in NumPy 2.x.
    # NOTE(review): a symmetric difference also keeps any training id that is
    # not in YEARS -- confirm whether a plain set difference is intended.
    test_env_ids = np.setxor1d(YEARS, train_env_ids)


# Echo the effective configuration, sorted by flag name.
print('Flags:')
for k, v in sorted(vars(flags).items()):
    print("    {}: {}".format(k, v))


def whiten(x):
    """Standardise ``x`` in place along dimension 0 and return it.

    Each slice along dim 0 is shifted by the mean over that dimension and
    divided by the (default, unbiased) standard deviation over that
    dimension. The work is done under ``torch.no_grad()`` and mutates the
    tensor that was passed in; the same object is returned.

    Columns whose standard deviation is zero (constant column) or undefined
    (a single row) are only centred, so they end up as zeros instead of
    NaN/inf.
    """
    with torch.no_grad():
        x -= x.mean(dim=0)
        std = x.std(dim=0)
        # ``std > 0`` is False for both 0 and NaN, so degenerate columns are
        # divided by 1 and keep their centred (all-zero) values.
        std = torch.where(std > 0, std, torch.ones_like(std))
        x /= std
    return x

def whiten1(x):
    """Standardise ``x`` in place using its global mean and std; return it.

    Unlike a per-column standardisation, the mean and (default, unbiased)
    standard deviation are taken over all elements. The work is done under
    ``torch.no_grad()`` and mutates the tensor that was passed in; the same
    object is returned.

    If the standard deviation is zero (all values equal) or undefined (a
    single element) the tensor is only centred, so the result is zeros
    instead of NaN/inf.
    """
    with torch.no_grad():
        x -= x.mean()
        std = x.std()
        # ``std > 0`` is False for both 0 and NaN: skip the division then.
        if std > 0:
            x /= std
    return x

#def mean_nll(logits, y):
#    return nn.functional.binary_cross_entropy_with_logits(logits, y)

def mean_nll(logits, y):
    """Return the mean squared error between ``logits`` and ``y``.

    Despite the name, this is a regression loss (MSE averaged over all
    elements), not a negative log-likelihood.
    """
    return nn.functional.mse_loss(logits, y, reduction='mean')


# def mean_accuracy(logits, y):
#     preds = (logits > 0.).float()
#     return ((preds - y).abs() < 1e-2).float().mean()

def mean_accuracy(logits, y):
    """Return the mean absolute gap between thresholded logits and ``y``.

    Logits are mapped to 1.0 where strictly positive and 0.0 otherwise; the
    result is the average of ``|prediction - y|``. Note that, despite the
    name, this is an error-style quantity (0 means a perfect match), not a
    fraction of correct predictions.
    """
    thresholded = torch.gt(logits, 0.).float()
    gap = torch.abs(thresholded - y)
    return gap.float().mean()


def env_irm_penalty(logits, y):
    """Return the IRM-style gradient penalty for one environment.

    A dummy scalar multiplier (fixed at 1) is applied to ``logits``; the
    penalty is the squared gradient of ``mean_nll`` with respect to that
    multiplier. ``create_graph=True`` keeps the result differentiable so it
    can be part of the training loss.
    """
    dummy_scale = torch.tensor(1.).cpu().requires_grad_()
    scaled_loss = mean_nll(logits * dummy_scale, y)
    (grad_wrt_scale,) = autograd.grad(scaled_loss, [dummy_scale], create_graph=True)
    return torch.mean(grad_wrt_scale ** 2)

def get_rex_penalty(train_envs):
    """Return the variance of the per-environment ``'nll'`` losses.

    ``train_envs`` is an iterable of dict-like environments, each holding a
    scalar tensor under the key ``'nll'``. ``torch.var`` is used with its
    default (unbiased) estimator.
    """
    per_env_losses = [environment['nll'] for environment in train_envs]
    return torch.var(torch.stack(per_env_losses))


class MLP(nn.Module):
    """Fully connected regressor: two hidden ReLU+Dropout layers, one output.

    The hidden width is read from the module-level ``flags.hidden_dim``.
    Linear layers keep PyTorch's default initialisation.
    """

    def __init__(self, input_size):
        """Build the network for inputs with ``input_size`` features."""
        super().__init__()
        self.input_size = input_size
        width = flags.hidden_dim
        # The three linear layers are created first and in input-to-output
        # order, matching the order in which their parameters are initialised.
        first_hidden = nn.Linear(input_size, width)
        second_hidden = nn.Linear(width, width)
        head = nn.Linear(width, 1)
        layers = [
            first_hidden, nn.ReLU(True), nn.Dropout(),
            second_hidden, nn.ReLU(True), nn.Dropout(),
            head,
        ]
        self._main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(x.shape[0], input_size)`` and apply the net."""
        flat = x.view(x.shape[0], self.input_size)
        return self._main(flat)

#%%

def Q_CRIC(env):
    """Sum squared gaps between each env's mean logit and re-weighted others.

    For every ordered pair of distinct environments ``(i, j)``:

    * ``q_e`` is the plain mean of environment ``i``'s ``'logits'``;
    * a density-ratio model is fitted with ``densratio(x_i, x_j)`` on the two
      environments' ``'images'`` and evaluated on ``x_j``;
    * ``q_e_co`` is the mean of environment ``j``'s logits multiplied
      element-wise by those weights;
    * ``(q_e_co - q_e) ** 2`` is added to the running total.

    Args:
        env: list (or tuple) of dict-like environments, each providing
            tensors under ``'images'`` and ``'logits'``.

    Returns:
        The accumulated total as a detached tensor, or the int ``0`` when
        fewer than two environments are given (no pair to compare).

    Raises:
        TypeError: if ``env`` is not a list or tuple.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not isinstance(env, (list, tuple)):
        raise TypeError(
            "env must be a list of environments, got {}".format(type(env).__name__))

    Q_num = 0
    for i, source in enumerate(env):
        q_e = torch.mean(source['logits'])
        # Converted once per source environment rather than once per pair.
        x = source['images'].detach().cpu().numpy()
        for j, target in enumerate(env):
            if j == i:
                continue
            x_t = target['images'].detach().cpu().numpy()
            dr_co = densratio(x, x_t)
            logits_t = target['logits'].reshape(-1)
            # One weight per sample of the target environment. Both operands
            # are flattened so the product is element-wise: multiplying an
            # (n,) weight vector by (n, 1) logits would broadcast to (n, n)
            # and the mean would no longer be a per-sample weighted mean.
            w_co = torch.from_numpy(dr_co.compute_density_ratio(x_t))
            w_co = w_co.reshape(-1).to(logits_t.device)
            q_e_co = torch.mean(w_co * logits_t)

            # Detach so the running total does not keep autograd graphs alive.
            Q_num = (Q_num + (q_e_co - q_e) ** 2).detach()
    return Q_num
# class MLP(torch.nn.Module):
#     def __init__(self, width_vec: list = None):
#         super(MLP, self).__init__()
#         self.width_vec= width_vec

#         modules = []
#         if width_vec is None:
#             width_vec = [256, 256]

#         # Network
#         for i in range(len(width_vec) - 1):
#             modules.append(
#                 nn.Sequential(
#                     nn.Linear(width_vec[i],width_vec[i+1]),
#                     nn.ReLU()))

#         self.net = nn.Sequential(*modules,
#                                  nn.Linear(width_vec[-1],1))

    #def forward(self, input):
    #    output = self.net(input)
     #   return  output





#%%

# Per-restart summaries: mean of the last 50 logged train/test 'acc' values,
# plus the full logger of every restart.
final_train_accs = []
final_test_accs = []
logs = []
for restart in range(flags.n_restarts):
    print("Restart", restart)
#%%
    # Load the environments for this restart (presumably lists of dicts with
    # 'images' and 'labels' tensors -- confirm against load.get_years).
    train_envs = get_years(train_env_ids)
    test_envs = get_years(test_env_ids)
    # preprocess: every environment is standardised with its own statistics
    # (features via whiten, labels via whiten1).
    for e in train_envs + test_envs:
        e['images'] = whiten(e['images'])
        e['labels'] = whiten1(e['labels'])
    print_env_info(train_envs, test_envs)


    # init: fresh logger, model and optimiser for every restart.
    logger = Logger()
    #mlp = MLP(train_envs[0]['images'].shape[1]).cpu()
    mlp = MLP(train_envs[0]['images'].shape[1]).cpu()
    optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)
    # NOTE(review): `loader` is never read anywhere in this file; training
    # below is full-batch over each environment.
    loader = Data.DataLoader(
    dataset=train_envs + test_envs, 
    batch_size=flags.BATCH_SIZE, 
    shuffle=False, num_workers=0,)

    pretty_print('step', 'train nll', 'train acc', 'irm penalty', 'rex penalty', 'test acc')
#%%
    for step in range(flags.steps):
        # Forward pass and per-environment metrics for train AND test envs.
        # NOTE(review): mlp.eval() is never called, so the Dropout layers are
        # presumably active for the test-environment metrics as well.
        for env in train_envs + test_envs:
            env['logits'] = mlp(env['images'])
            env['nll'] = mean_nll(env['logits'], env['labels'])
            env['acc'] = mean_accuracy(env['logits'], env['labels'])
            env['penalty'] = env_irm_penalty(env['logits'], env['labels'])

        # Aggregate over the training environments only.
        train_nll = torch.stack([e['nll'] for e in train_envs]).mean()
        train_acc = torch.stack([e['acc'] for e in train_envs]).mean()
        irm_penalty = torch.stack([e['penalty'] for e in train_envs]).mean()
        rex_penalty = get_rex_penalty(train_envs)

        # Squared L2 norm of all model parameters.
        weight_norm = torch.tensor(0.).cpu()
        for w in mlp.parameters():
            weight_norm += w.norm().pow(2)

        loss = train_nll.clone()
        loss += flags.l2_regularizer_weight * weight_norm

        # Penalty weighting: once step >= penalty_anneal_iters the fit + L2
        # term is divided by the penalty weight before the (unweighted)
        # penalty is added, i.e. (fit + weight * penalty) / weight. Before
        # that step the penalty is simply added with weight 1.
        if flags.irm_penalty_weight:
            if step >= flags.penalty_anneal_iters:
                loss /= flags.irm_penalty_weight
            loss += irm_penalty

        # The REx penalty is only used when irm_penalty_weight is 0.
        elif flags.rex_penalty_weight:
            if step >= flags.penalty_anneal_iters:
                loss /= flags.rex_penalty_weight
            loss += rex_penalty

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.log('train_nll', train_nll)
        logger.log('train_acc', train_acc)
        logger.log('irm_penalty', irm_penalty)
        logger.log('rex_penalty', rex_penalty)
        logger.log('test_acc', [e['acc'] for e in test_envs])
        logger.log('losses', [e['nll'] for e in train_envs])

        # With the default --steps of 300 this prints at step 0 only.
        if step % 1000 == 0:
            print_stats(step, logger)
#%%
    # Summarise this restart by the mean of the last 50 logged values.
    final_train_accs.append(np.mean(logger['train_acc'][-50:]))
    final_test_accs.append(np.mean(logger['test_acc'][-50:]))
    print('Final train acc (mean/std across restarts so far):')
    print(np.mean(final_train_accs), np.std(final_train_accs))
    print('Final test acc (mean/std across restarts so far):')
    print(np.mean(final_test_accs), np.std(final_test_accs))

    logs.append(logger)

    if flags.plot:
        plot(logger)
#%%
#if flags.save:
#    save(logs, 'results/%s_%s_%s' % (flags.save, ','.join([str(e) for e in train_env_ids]), ','.join([str(e) for e in test_env_ids])))

# CRIC diagnostic on the environments left in scope by the last restart:
# once over the training environments, once over the test environments.
Q_num = Q_CRIC(train_envs)
Q_phi_t = Q_CRIC(test_envs)



